{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['HTTP_PROXY'] = 'http://127.0.0.1:7890'\n",
    "os.environ['HTTPS_PROXY'] = 'http://127.0.0.1:7890'\n",
    "\n",
    "import json\n",
    "import random\n",
    "from datasets import load_dataset\n",
    "\n",
    "from settings import LLM_CLS_FORMAT, HUMAN_FORMAT, categories, LABELS_DICT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_abstract = load_dataset(\"ccdv/patent-classification\", \"abstract\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 原始数据集构造"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 25000\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 5000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['text', 'label'],\n",
       "        num_rows: 5000\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ds_abstract"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_LLM(dataset, output_file, num=-1, shuffle=False):\n",
    "    res = []\n",
    "    for item in dataset:\n",
    "        text = item[\"text\"]\n",
    "        label = item[\"label\"]\n",
    "        llm_data = {\n",
    "                'instruction': LLM_CLS_FORMAT.format(categories=\"\\n\".join(categories)),\n",
    "                \"input\":HUMAN_FORMAT.format(input=text),\n",
    "                \"output\": LABELS_DICT[label]\n",
    "            }\n",
    "        res.append(llm_data)\n",
    "    \n",
    "    if shuffle:\n",
    "        random.shuffle(res)\n",
    "    \n",
    "    if num > 0:\n",
    "        res = res[:num]\n",
    "        \n",
    "    with open(output_file, 'w') as w:\n",
    "        w.write(json.dumps(res, ensure_ascii=False, indent=2) + '\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "for num in [100, 500, 1000, 2000]:\n",
    "    convert_LLM(\n",
    "        dataset = ds_abstract[\"train\"], \n",
    "        output_file = f\"data/llm_train_{num}.json\",\n",
    "        num=num,\n",
    "        shuffle=True\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "convert_LLM(ds_abstract[\"validation\"], \"data/llm_valid.json\")\n",
    "convert_LLM(ds_abstract[\"test\"], \"data/llm_test.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
